{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aaed9cbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import task\n",
    "import deit\n",
    "import trocr_models\n",
    "import torch\n",
    "import fairseq\n",
    "from fairseq import utils\n",
    "from fairseq_cli import generate\n",
    "from PIL import Image\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "\n",
    "def init(model_path, beam=5, device=None):\n",
    "    \"\"\"Load a TrOCR checkpoint and build everything needed for inference.\n",
    "\n",
    "    Args:\n",
    "        model_path: path to the fairseq checkpoint to load.\n",
    "        beam: beam size written into the generation config via ``arg_overrides``.\n",
    "        device: torch device for the model(s); defaults to 'cuda' when available,\n",
    "            otherwise 'cpu'.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(model, cfg, task, generator, bpe, img_transform, device)``.\n",
    "        ``model`` is the list of loaded models (one entry per checkpoint path) and\n",
    "        ``task`` is the fairseq task object, not the imported ``task`` module.\n",
    "    \"\"\"\n",
    "    # 'text_recognition' is presumably registered by the ``import task`` line above;\n",
    "    # 'data' is blanked, presumably because no dataset is read for single-image inference.\n",
    "    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(\n",
    "        [model_path],\n",
    "        arg_overrides={\"beam\": beam, \"task\": \"text_recognition\", \"data\": \"\", \"fp16\": False})\n",
    "\n",
    "    if device is None:\n",
    "        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "    # Move every ensemble member (not only model[0]) because the whole list is\n",
    "    # handed to the generator, and switch to inference mode (disables dropout).\n",
    "    for m in model:\n",
    "        m.to(device)\n",
    "        m.eval()\n",
    "\n",
    "    img_transform = transforms.Compose([\n",
    "        transforms.Resize((384, 384), interpolation=3),  # 3 == bicubic\n",
    "        transforms.ToTensor(),  # PIL RGB [0, 255] -> CHW float [0, 1]\n",
    "        transforms.Normalize(0.5, 0.5)  # (x - 0.5) / 0.5 -> [-1, 1]\n",
    "    ])\n",
    "\n",
    "    generator = task.build_generator(\n",
    "        model, cfg.generation, extra_gen_cls_kwargs={'lm_model': None, 'lm_weight': None}\n",
    "    )\n",
    "\n",
    "    # NOTE(review): presumably None when the checkpoint configures no BPE.\n",
    "    bpe = task.build_bpe(cfg.bpe)\n",
    "\n",
    "    return model, cfg, task, generator, bpe, img_transform, device\n",
    "\n",
    "\n",
    "def preprocess(img_path, img_transform, device=None):\n",
    "    \"\"\"Load one image and wrap it as the ``sample`` dict expected by ``get_text``.\n",
    "\n",
    "    Args:\n",
    "        img_path: path to the image file.\n",
    "        img_transform: the transform returned by ``init``.\n",
    "        device: device for the image tensor; pass the ``device`` returned by ``init``.\n",
    "            Defaults to 'cuda' when available, otherwise 'cpu' (``init``'s default).\n",
    "\n",
    "    Returns:\n",
    "        ``{'net_input': {'imgs': tensor}}`` where the tensor has a leading batch\n",
    "        dimension of 1.\n",
    "    \"\"\"\n",
    "    # Explicit parameter instead of silently reading a notebook-global ``device``.\n",
    "    if device is None:\n",
    "        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    # ``with`` closes the file handle; convert() returns an independent in-memory image.\n",
    "    with Image.open(img_path) as raw_img:\n",
    "        im = raw_img.convert('RGB').resize((384, 384))\n",
    "    im = img_transform(im).unsqueeze(0).to(device).float()\n",
    "\n",
    "    sample = {\n",
    "        'net_input': {\"imgs\": im},\n",
    "    }\n",
    "\n",
    "    return sample\n",
    "\n",
    "\n",
    "def get_text(cfg, generator, model, sample, bpe, fairseq_task=None):\n",
    "    \"\"\"Run beam search on ``sample`` and return the top hypothesis as text.\n",
    "\n",
    "    Args:\n",
    "        cfg: config returned by ``init``; ``cfg.common_eval.post_process`` is passed\n",
    "            to fairseq as the BPE-removal mode.\n",
    "        generator: sequence generator returned by ``init``.\n",
    "        model: list of models returned by ``init``; ``model[0]``'s decoder\n",
    "            dictionary maps token ids back to symbols.\n",
    "        sample: dict produced by ``preprocess``.\n",
    "        bpe: BPE object returned by ``init``; decoding is skipped when it is None.\n",
    "        fairseq_task: fairseq task returned by ``init``. If omitted, the global\n",
    "            name ``task`` is used, which is only correct after it has been rebound\n",
    "            to ``init``'s return value (initially it is the imported ``task`` module).\n",
    "\n",
    "    Returns:\n",
    "        The decoded string of the top-1 hypothesis for the first image in the batch.\n",
    "    \"\"\"\n",
    "    if fairseq_task is None:\n",
    "        fairseq_task = task  # legacy behaviour: fall back to the notebook global\n",
    "    hypos = fairseq_task.inference_step(generator, model, sample, prefix_tokens=None, constraints=None)\n",
    "    best = hypos[0][0]  # first image in the batch, top-1 hypothesis\n",
    "\n",
    "    _, hypo_str, _ = utils.post_process_prediction(\n",
    "        hypo_tokens=best[\"tokens\"].int().cpu(),\n",
    "        src_str=\"\",\n",
    "        alignment=best.get(\"alignment\"),  # tolerate generators that omit it\n",
    "        align_dict=None,\n",
    "        tgt_dict=model[0].decoder.dictionary,\n",
    "        remove_bpe=cfg.common_eval.post_process,\n",
    "        extra_symbols_to_ignore=generate.get_symbols_to_strip_from_output(generator),\n",
    "    )\n",
    "\n",
    "    # Mirrors fairseq_cli.generate's decode_fn, which skips BPE decoding when bpe is None.\n",
    "    if bpe is None:\n",
    "        return hypo_str\n",
    "    return bpe.decode(hypo_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b95c01e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: point these at a real checkpoint and image before running.\n",
    "MODEL_PATH = 'path/to/model'\n",
    "IMG_PATH = \"path/to/pic\"\n",
    "BEAM = 5  # beam-search width\n",
    "\n",
    "# NOTE: this rebinds the global name `task` (until now the imported `task` module)\n",
    "# to the fairseq task object; get_text() may look `task` up from the notebook\n",
    "# globals, so keep this exact name.\n",
    "model, cfg, task, generator, bpe, img_transform, device = init(MODEL_PATH, beam=BEAM)\n",
    "\n",
    "sample = preprocess(IMG_PATH, img_transform)\n",
    "recognized_text = get_text(cfg, generator, model, sample, bpe)\n",
    "\n",
    "print(recognized_text)\n",
    "print('done')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "0b8488e5f98ef3932f4ff0893213e55e6ba8b00dde307078d0f3efb25017ce11"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
